library(tidyverse)
library(haven)
library(texreg)
library(broom)
library(prediction)
library(dotwhisker)

# Destination folder for every table and figure written by this script.
outputs_path <- "../../outputs/replication/"

# Project helper scripts. NOTE(review): presumably these define
# get_gov_op_predictions(), tidy_predictions(), generic_tidy_texreg() and
# texreg_variable_list, which are used below but not defined here -- confirm.
source("functions_panel_models.R")
source("texreg_variable_list.R")

# Tidied experiment data; every specification below is fitted on (a subset of)
# this data frame.
ex <- read_rds("../../data/replication/experiment_data_tidied.rds")

# Control variables appended to every model formula built by ols_gov_op().
controls_vec <- c(
  "treatment_w21", "male_w21", "white_british_w21",
  "age_w21", "class_w21", "education_w21"
)

# Dependent variables paired with the display names used in tables and plots.
# Holding both in one tibble keeps each model aligned with its label.
dvars <- tribble(
  ~variables,                     ~neat_names,
  "uk_covid_performance_w21_num", "Overall",
  "covid_overall_gov_fault_w21",  "Overall resp",
  "death_toll_gov_fault_w21",     "Death toll",
  "vaccine_gov_fault_w21",        "Vaccine",
  "retro_handle_w21_num",         "Retro handle",
  "retro_handle_change_w21",      "Change in retro handle"
)

# Name of the specification reported in the regression table and main plot.
# It is stored as a string so that get() can retrieve both the model list and
# its matching "<name>_predictions" object consistently.
presented_model_name <- "robust_party"

# now all the other models that could be chosen, and which can be added to the
# rope ladder plot

# Fit a weighted OLS model of one outcome on a partisanship variable plus the
# controls in `controls_vec`.
#
# dependent_variable: name (string) of the outcome column.
# extra_variable:     name (string) of the focal predictor column.
# weights:            name (string) of the weight column in `model_data`.
# Returns the fitted `lm` object.
# Reads the globals `model_data` and `controls_vec`. `model_data` is passed to
# lm() by name, so each model's stored call refers to the global symbol.
# Defined before its first use below; previously map() referenced it before
# the definition ran.
ols_gov_op <- function(dependent_variable,
                       extra_variable,
                       weights) {
  with_dems_form <- formula(paste0(dependent_variable,
                                   " ~ ",
                                   extra_variable,
                                   " + ",
                                   paste0(controls_vec, collapse = " + ")))
  lm(with_dems_form,
     data = model_data,
     weights = model_data[[weights]])
}

# NOTE(review): this global formula looks load-bearing, not leftover. Every
# model's call stores the symbol `with_dems_form`, and the prediction package
# presumably re-evaluates that symbol outside ols_gov_op() when recovering the
# model data. All specifications below share this right-hand side
# (party_id_gov_op_w19 + controls_vec), so the outcome named here should not
# matter. Confirm before removing.
with_dems_form <- formula(paste0("uk_covid_performance_w21_num",
                                 " ~ ",
                                 "party_id_gov_op_w19",
                                 " + ",
                                 paste0(controls_vec, collapse = " + ")))

# note can't use prediction if data is given as a function, so need to
# set the data beforehand. Not ideal but it's fine for this.
# Predictions for a specification must therefore be computed before
# `model_data` is reassigned for the next one.
model_data <- ex
base_models <- map(dvars$variables,
                   ols_gov_op,
                   extra_variable = "party_id_gov_op_w19",
                   weights = "wt_new_W21")
base_predictions <- map2(base_models,
                         dvars$variables,
                         get_gov_op_predictions) %>%
  bind_rows() %>%
  as_tibble() %>%
  mutate(spec = "Base model") %>%
  tidy_predictions()

# Spec 2: full sample, re-weighted with the political-attention weights.
model_data <- ex
pol_att_weighting <- map(dvars$variables,
                         ols_gov_op,
                         extra_variable = "party_id_gov_op_w19",
                         weights = "wt_with_pol_att_w21")

pol_att_weighting_predictions <- pol_att_weighting %>%
  map2(dvars$variables, get_gov_op_predictions) %>%
  bind_rows() %>%
  as_tibble() %>%
  mutate(spec = "+ weights for pol. attention") %>%
  tidy_predictions()

# Spec 3: as spec 2, but excluding rows flagged rm_for_vaccine_attitude == 1
# (filter() also drops rows where the flag is NA).
model_data <- ex %>%
  filter(rm_for_vaccine_attitude != 1)
robust_vaccine <- map(dvars$variables,
                      ols_gov_op,
                      extra_variable = "party_id_gov_op_w19",
                      weights = "wt_with_pol_att_w21")
robust_vaccine_predictions <- robust_vaccine %>%
  map2(dvars$variables, get_gov_op_predictions) %>%
  bind_rows() %>%
  as_tibble() %>%
  mutate(spec = "+ remove vaccine hesitant") %>%
  tidy_predictions()

# Spec 4 (the default presented spec): as spec 3, additionally excluding rows
# flagged rm_for_party == 1. A single filter() with two conditions keeps
# exactly the rows for which both are TRUE, as the chained filters did.
model_data <- ex %>%
  filter(rm_for_vaccine_attitude != 1,
         rm_for_party != 1)
robust_party <- map(dvars$variables,
                    ols_gov_op,
                    extra_variable = "party_id_gov_op_w19",
                    weights = "wt_with_pol_att_w21")
robust_party_predictions <- robust_party %>%
  map2(dvars$variables, get_gov_op_predictions) %>%
  bind_rows() %>%
  as_tibble() %>%
  mutate(spec = "+ remove by party") %>%
  tidy_predictions()

# Spec 5: vaccine exclusion only (rm_for_party is NOT applied here), restricted
# to rows whose partyIdW20 is 1, 2 or 10. The spec label below describes these
# codes as Con / Lab / None.
model_data <- ex %>%
  filter(rm_for_vaccine_attitude != 1,
         partyIdW20 %in% c(1, 2, 10))
just_con_lab <- map(dvars$variables,
                    ols_gov_op,
                    extra_variable = "party_id_gov_op_w19",
                    weights = "wt_with_pol_att_w21")
just_con_lab_predictions <- just_con_lab %>%
  map2(dvars$variables, get_gov_op_predictions) %>%
  bind_rows() %>%
  as_tibble() %>%
  mutate(spec = "+ just Con / Lab / None") %>%
  tidy_predictions()

# Stack the predictions from every specification.
# `facet` records the partisan group taken from `term`; any term other than the
# two listed levels becomes NA. `term` is then overwritten with the
# specification label. mutate() evaluates left to right, so `facet` still sees
# the original `term`.
spec_prediction_list <- list(base_predictions,
                             pol_att_weighting_predictions,
                             robust_vaccine_predictions,
                             robust_party_predictions,
                             just_con_lab_predictions)

all_predictions <- spec_prediction_list %>%
  bind_rows() %>%
  mutate(facet = factor(term, levels = c("Government partisan",
                                         "Opposition partisan")),
         term = spec) %>%
  arrange(facet)

# Wrapper around generic_tidy_texreg() that fixes this file's table options:
# not sideways, a stars-only note, and a 0.6 scalebox. It is defined locally
# because other model files may want different coefficient maps or scaling.
# The coefficient map is supplied by callers through `...`, along with any
# other texreg options, which are forwarded unchanged.
# Returns whatever generic_tidy_texreg() returns (visibly).
custom_texreg <- function(list_of_models, ...) {
  table_out <- generic_tidy_texreg(
    list_of_models,
    sideways = FALSE,
    custom.note = "\\normalsize{%stars}",
    scalebox = 0.6,
    ...
  )
  table_out
}

# Retrieve the model list named by `presented_model_name` (set near the top of
# the script) and write it as a LaTeX regression table.
presented_model <- get(presented_model_name)

caption_suffix <- paste("Results from OLS regressions of all dependent",
                        "variables on party ID")

custom_texreg(
  presented_model,
  # one header spanning every dependent-variable column
  custom.header = list("\\textit{Dependent variable:}" =
                         seq_along(dvars$neat_names)),
  custom.model.names = dvars$neat_names,
  caption = paste0(caption_suffix, ", with full controls added"),
  custom.coef.map = texreg_variable_list,
  label = "tab: gov_op_intx_controls",
  # row ranges refer to positions in the coefficient map
  groups = list("Partisanship: (ref = Opposition)" = 1:3,
                "Treatment: (ref = Control)" = 4:5,
                "Age: (ref = Under 18)" = 6:10,
                "Class: (ref = Higher Managerial)" = 11:15,
                "Education: (ref = Undergrad or Higher)" = 16:19),
  file = paste0(outputs_path, "panel_with_controls.tex")
)

# Significance is judged by eye. The government and opposition confidence
# intervals do not overlap for any outcome except the change in retrospective
# handling, where they overlap almost completely. Only that outcome is marked
# as showing no significant difference.
presented_model_predictions <-
  get(paste0(presented_model_name, "_predictions")) %>%
  mutate(signif_diff = case_when(
    model %in% "retro_handle_change_w21" ~ "No Signif. Difference",
    TRUE ~ "Signif. Difference"
  ))

# Main figure: predicted positions for the presented specification, with one
# panel per outcome. Everything is drawn in black. Point shape encodes the
# by-eye significance flag, but both the colour and shape legends are hidden,
# so no legend title is set for them.
# The manual palette length now follows the data rather than a hard-coded 6,
# so adding or removing rows in `dvars` no longer breaks the scale.
n_outcome_models <- n_distinct(presented_model_predictions$model)

all_plot <- dwplot(presented_model_predictions,
                   dodge_size = 0.80,
                   dot_args = list(aes(shape = signif_diff))) +
  theme_bw() +
  theme(legend.position = "bottom") +
  xlab("Average predicted position") +
  guides(color = "none",
         shape = "none") +
  # one black value per distinct `model` (dwplot presumably colours by model)
  scale_colour_manual(values = rep("#000000", n_outcome_models)) +
  facet_wrap(~neat_names, nrow = 3, scales = "free_x")

# Write the main figure as both PDF and PNG at the same size, PDF first.
walk(c(".pdf", ".png"), function(extension) {
  ggsave(filename = paste0(outputs_path, "panel_all_plotted", extension),
         plot = all_plot,
         width = 6,
         height = 7.2,
         units = "in")
})

# Robustness plot: one panel per outcome and one row per specification, with
# the two partisan groups dodged and coloured.
# The column swap makes the partisan group dwplot()'s `model` and keeps the
# outcome variable name in `facet`.
robust_predictions <- all_predictions %>%
  rename(temp = facet) %>%
  rename(facet = model) %>%
  rename(model = temp) %>%
  # `model` is a factor whose only levels are the two partisan groups, so any
  # other term (e.g. "Non-partisan") is already NA. Drop those rows explicitly;
  # this is what the previous `model != "Non-partisan"` filter did in effect.
  filter(!is.na(model))

# Colours are keyed by name rather than level order, so the dots and legend
# agree with the shading added later: Government = "#0087DC", Opposition =
# "#E4003B". `scales` is spelled out; `scale =` only worked through partial
# argument matching.
robustness_plot <- dwplot(robust_predictions,
                          dodge_size = 0.80,
                          dot_args = list(aes(colour = model))) +
  theme_bw() +
  scale_colour_manual(values = c("Government partisan" = "#0087DC",
                                 "Opposition partisan" = "#E4003B"),
                      name = "Partisanship") +
  theme(legend.position = "bottom") +
  xlab("Average predicted position") +
  facet_wrap(~neat_names, nrow = 3, scales = "free_x")

# Shaded band for each partisan group and outcome, spanning the min-max range
# of point estimates across specifications.
# mutate() keeps one row per estimate, so each band is drawn once per row and
# the alpha = 0.05 layers stack. The data is ungrouped afterwards so no stray
# grouping leaks into later use.
rob_rect <- robust_predictions %>%
  group_by(model, neat_names) %>%
  mutate(min_rect = min(estimate),
         max_rect = max(estimate)) %>%
  ungroup()

final_robust_plot <- robustness_plot +
  geom_rect(data = rob_rect %>% filter(model == "Government partisan"),
            aes(xmin = min_rect, xmax = max_rect, ymin = -Inf, ymax = Inf),
            alpha = 0.05, fill = "#0087DC") +
  geom_rect(data = rob_rect %>% filter(model == "Opposition partisan"),
            aes(xmin = min_rect, xmax = max_rect, ymin = -Inf, ymax = Inf),
            alpha = 0.05, fill = "#E4003B")

# Save the robustness figure (PDF only). Arguments are named explicitly
# instead of relying on `file` partially matching ggsave()'s `filename`.
ggsave(filename = paste0(outputs_path, "panel_robust", ".pdf"),
       plot = final_robust_plot,
       width = 6,
       height = 7.2,
       units = "in")

